#!/bin/bash
# -*- coding: utf-8 -*-
#
# infer_score.sh
#
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Example dglke_predict invocations against pretrained checkpoints under
# ckpts/. Every path below is relative, so run this script from the
# directory that contains ckpts/ and the head.list / rel.list / tail.list
# input files.
#
# NOTE: the shebang must stay on line 1; anywhere else it is only a comment.

# All pretrained models used below can be generated by any of the training
# scripts in multi_gpu.sh, multi_cpu.sh or dist_train.sh.

# Inference using the TransE_l1 pretrained model: one run without --bcast,
# then one run for each --bcast side (head, rel, tail), in that order.
transe_l1_predict() {
  dglke_predict --model_path ckpts/TransE_l1_wn18_0/ --format h_r_t \
    --data_files head.list rel.list tail.list --topK 5 "$@"
}
transe_l1_predict
for side in head rel tail; do
  transe_l1_predict --bcast "$side"
done

# Inference using the TransE_l2 pretrained model with the logsigmoid score
# function: one run without --bcast, then one per --bcast side.
# Extra arguments are spliced in before --score_func to keep the same
# argument order for every call.
transe_l2_predict() {
  dglke_predict --model_path ckpts/TransE_l2_wn18_0/ --format h_r_t \
    --data_files head.list rel.list tail.list --topK 5 "$@" \
    --score_func logsigmoid
}
transe_l2_predict
for side in head rel tail; do
  transe_l2_predict --bcast "$side"
done

# Inference using the DistMult pretrained model, using the whole tail set as
# tail ('*' in the format), with --gpu 0: one run without --bcast, then one
# per --bcast side.
# The format is single-quoted: unquoted, the shell would pathname-expand
# h_r_* against files in the current directory (e.g. a file named h_r_t)
# and pass that filename to --format instead of the literal pattern.
distmult_predict() {
  dglke_predict --model_path ckpts/DistMult_wn18_0/ --format 'h_r_*' \
    --data_files head.list rel.list --topK 5 "$@" --gpu 0
}
distmult_predict
for side in head rel tail; do
  distmult_predict --bcast "$side"
done

# Inference using the ComplEx pretrained model, using the whole relation set
# as relation ('*' in the format), with --gpu 0: one run without --bcast,
# then one per --bcast side.
# The format is single-quoted: unquoted, the shell would pathname-expand
# h_*_t against files in the current directory (e.g. a file named h_r_t)
# and pass that filename to --format instead of the literal pattern.
complex_predict() {
  dglke_predict --model_path ckpts/ComplEx_wn18_0/ --format 'h_*_t' \
    --data_files head.list tail.list --topK 15 "$@" --gpu 0
}
complex_predict
for side in head rel tail; do
  complex_predict --bcast "$side"
done

# Inference using the RESCAL pretrained model, using the whole head set as
# head ('*' in the format), with --gpu 0: one run without --bcast, then one
# per --bcast side.
# The format is single-quoted: unquoted, the shell would pathname-expand
# *_r_t against files in the current directory (e.g. a file named h_r_t)
# and pass that filename to --format instead of the literal pattern.
rescal_predict() {
  dglke_predict --model_path ckpts/RESCAL_wn18_0/ --format '*_r_t' \
    --data_files rel.list tail.list --topK 15 "$@" --gpu 0
}
rescal_predict
for side in head rel tail; do
  rescal_predict --bcast "$side"
done

# Inference using the RotatE pretrained model (top 15): one run without
# --bcast, then one run for each --bcast side (head, rel, tail).
rotate_predict() {
  dglke_predict --model_path ckpts/RotatE_wn18_0/ --format h_r_t \
    --data_files head.list rel.list tail.list --topK 15 "$@"
}
rotate_predict
for side in head rel tail; do
  rotate_predict --bcast "$side"
done
